# -*- coding: utf-8 -*-
"""
The Python script that generates Figure 7.
It includes: 1. reading hourly air temperature data over lake and rural (non-lake) land tiles;
2. the necessary data processing (computing long-term mean diurnal temperature cycles and converting UTC to local time);
3. plotting 2D maps, with latitudinal means, of the lake-minus-nonlake air temperature during 2019-2023
and of the change in lake-minus-nonlake air temperature between 2019-2023 and 2096-2100.

"""
import netCDF4 as nc
import numpy as np
from mpl_toolkits.basemap import Basemap
import gc
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm
import warnings
warnings.filterwarnings("ignore") # Silence all warnings (presumably numpy's all-NaN "Mean of empty slice" warnings, among others)
Dir ='G:\\SSP5UC_DataPreparation\\deposit\\Hourly\\' # The directory you are working with

# Read the first 24 hourly records of the 2019 lake and rural (non-lake) air temperature.
# Only their masks are used (in ConvertUTC2LTC): they restrict the analysis to grid cells
# that contain both a lake tile and a rural land tile.
# Slicing a netCDF4 variable loads the data into memory, so the files can be closed right away.
with nc.Dataset(Dir+'b.e21.SSP585UC_IndividualSoil_tasl_Lake_Hourly_2019.nc') as FileLake:
    MaskArray_L=FileLake.variables['tasl'][0:24,:,:]
with nc.Dataset(Dir+'b.e21.SSP585UC_IndividualSoil_tasr_Rural_Hourly_2019.nc') as FileRural:
    MaskArray_R=FileRural.variables['tasr'][0:24,:,:]
LonNum=288 # Number of longitudinal grids
LatNum=192 # Number of latitudinal grids

def CalMeanDuirnalCycle(Var):
    """
    CalMeanDuirnalCycle
    Description: a function to calculate the long-term mean diurnal cycle of an hourly series
    Input:  Var--continuous hourly data in the shape of (hours, ...), normally (hours,LatNum,LonNum).
                 Trailing records that do not complete a full 24-hour day are ignored.
    Output: VarMean--masked array in the shape of (24, ...). VarMean[h] is the NaN-ignoring mean
                 of records h, h+24, h+48, ..., so it is "hour h" only if Var starts at 00:00.
    Raises: ValueError--if Var has fewer than 24 records (no complete day to average).
    """
    DayCount=Var.shape[0]//24 # Number of complete days
    if DayCount==0:
        raise ValueError('CalMeanDuirnalCycle needs at least 24 hourly records, got %d' % Var.shape[0])
    # Start fully masked (NaN underneath) with the spatial shape of the input
    VarMean=np.ma.masked_invalid(np.full((24,)+Var.shape[1:],np.nan))
    for hour in range(24): # loop through 0:00-23:00 and average the same hour over all complete days
        VarMean[hour]=np.nanmean(Var[hour:24*DayCount:24],axis=0)
    return VarMean

def ConvertUTC2LTC(UTCData):
    """
    Description: re-order a 24-hour UTC diurnal cycle into local (time-zone) time, column by column.
    Input:  UTCData--(24,LatNum,LonNum) array whose first axis is the UTC hour 0-23
    Output: LTData--masked (24,LatNum,LonNum) array whose first axis is the local hour 0-23.
            Cells masked in the module-level MaskArray_R or MaskArray_L are masked in the output.
    Notes:  The LonNum columns are split evenly into 24 one-hour time zones of LonNum/24 columns.
            Zone k is treated as UTC+k (equivalently UTC+(k-24)), so local hour h is taken from
            UTC hour (h-k) mod 24.
            NOTE(review): assumes column 0 is 0°E, uniform longitude spacing and LonNum divisible
            by 24 -- TODO confirm for other grids.
    """
    # TZLon is a 24 x (LonNum/24) helper array: row k lists the longitude (column) indices of time zone k.
    # Row k (k>=1) is the contiguous run TZInterval*(k-1)+HalfInterval ... TZInterval*k+HalfInterval-1.
    # Row 0 is centred on column 0, so it wraps around: the last HalfInterval columns plus the first ones.
    # Zone k is UTC+k for k<=12 and UTC+(k-24) (west of Greenwich) for k>12.
    TZLon=np.empty((24, int(LonNum/24)))
    TZInterval=int(LonNum/24) # Columns per time zone (LonNum is split evenly into 24 zones)
    HalfInterval=int(TZInterval/2)
    for j in range(0,24):
        TZLon[j,:]=np.arange(int(HalfInterval+1+TZInterval*(j-1)),int(HalfInterval+1+TZInterval*(j))) -1 
    TZLon[0,0:HalfInterval]=np.arange(LonNum+1-HalfInterval,LonNum+1)-1 # Zone 0 wraps: its western half is the last HalfInterval columns
    TZLon=np.int16(TZLon)  
    
    # Start from a fully masked (24,LatNum,LonNum) array (NaN underneath); every column is filled below
    LTData=np.empty((24,LatNum,LonNum));LTData[:] = np.nan; LTData= np.ma.masked_invalid(LTData)
    
    # For the columns of each time zone, rotate the 24 hours so that local midnight comes first, then 1 AM, 2 AM, ...
    for k in range(0,24):
        if k==0: # Zone 0 (UTC+0): local time equals UTC, so copy both wrapped halves unchanged
            LTData[0:24,:,LonNum-HalfInterval:LonNum]=UTCData[0:24,:,LonNum-HalfInterval:LonNum]
            LTData[0:24,:,0:HalfInterval]=UTCData[0:24,:,0:HalfInterval]
        elif k in range(1,24): # Zone k (UTC+k): local hour 0 is UTC hour 24-k, so move UTC hours 24-k..23 to the front
            Temp=24-k
            LTData[0:24,:,TZLon[k,0]:(TZLon[k,TZInterval-1]+1)]=np.ma.concatenate((UTCData[Temp:24,:,TZLon[k,0]:(TZLon[k,TZInterval-1]+1)],UTCData[0:Temp,:,TZLon[k,0]:(TZLon[k,TZInterval-1]+1)]),axis=0)

    # Mask cells without a rural tile or without a lake tile, so only cells with both are used.
    # NOTE(review): the masks are indexed by UTC hour while LTData is indexed by local hour; this is only
    # equivalent if the masks do not vary over the day (presumably true for tile presence -- confirm).
    LTData=np.ma.masked_where(np.ma.getmask(MaskArray_R),LTData) 
    LTData=np.ma.masked_where(np.ma.getmask(MaskArray_L),LTData)
    return LTData

def CalDayNightMean(Var):
    """
    Description: average a local-time diurnal cycle over a daytime and a nighttime window
    Input:  Var--(24,LatNum,LonNum) array whose first axis is the local hour 0-23
    Output: (VarDay, VarNight), each of shape (1,LatNum,LonNum) so that per-year results can be
            stacked along axis 0:
            VarDay--NaN-ignoring mean over hours 8..16 (8AM-4PM)
            VarNight--NaN-ignoring mean over hours 0..4 and 20..23 (8PM-4AM)
    """
    EarlyMorning = Var[0:5, :, :]   # 00:00-04:00
    LateEvening = Var[20:24, :, :]  # 20:00-23:00
    NightHours = np.ma.concatenate((EarlyMorning, LateEvening), axis=0)
    DayHours = Var[8:17, :, :]      # 08:00-16:00

    VarDay = np.nanmean(DayHours, axis=0)[None, :, :]
    VarNight = np.nanmean(NightHours, axis=0)[None, :, :]
    return VarDay, VarNight

def Cal_latitudinal_mean(Var, min_grids=4):
    """
    Description: a function to calculate the latitudinal (row-wise) mean values of a 2D variable
    Input:  Var--data in the shape of (LatNum,LonNum); masked arrays and plain ndarrays are accepted.
                 Masked cells and NaN cells are both treated as missing.
            min_grids--a row's mean is computed only when the row has MORE than this many valid
                 (unmasked, non-NaN) grid cells; otherwise that row is NaN (default 4)
    Output: LatMeanVar--float array of length Var.shape[0], flipped so that index 0 holds the
                 LAST input row (presumably the northernmost, since this script keeps the
                 southern hemisphere in rows 0..LatNum/2-1)
    """
    Data = np.ma.asarray(Var)  # accept plain ndarrays too (they have no .compressed())
    LatMeanVar = np.full(Data.shape[0], np.nan)  # rows with too few valid cells stay NaN
    for i, Row in enumerate(Data):  # every row of the input
        Valid = Row.compressed()  # unmasked cells only
        Valid = Valid[~np.isnan(Valid)]  # NaN cells are ignored by the mean, so do not count them either
        if Valid.size > min_grids:
            LatMeanVar[i] = Valid.mean()
    return np.flip(LatMeanVar, axis=0)

def PlotContour(TobePlot, Lim1, Lim2, Interval,Title,Colorscheme):
    """
    Description: a function to visualize Lake-Nonlake air temperature or any other 2D global data
                 on a cylindrical Basemap drawn on the current axes
    Input:  TobePlot--2D data in the shape of (lat, lon); the grid size is taken from this array.
                      NOTE(review): assumes column 0 is 0°E and the columns span 0-360°E -- confirm for other grids
            Lim1,Lim2--lower bound and upper bound of the colour levels (np.arange semantics:
                      Lim2 itself is normally excluded, subject to floating-point rounding)
            Interval--spacing of the colour levels
            Title--accepted for interface compatibility; it is NOT drawn on the figure
            Colorscheme--name of a matplotlib colormap
    Output: the mesh object returned by pcolormesh (usable for a shared colorbar)
    """
    ny, nx = TobePlot.shape[0], TobePlot.shape[1]
    # Draw the basemap
    m = Basemap(projection='cyl',lon_0=0.,lat_0=0.,lat_ts=0.,fix_aspect=False,\
                llcrnrlat=-90,urcrnrlat=90,\
                llcrnrlon=-180,urcrnrlon=180.0,\
                rsphere=6371200.,resolution='l',area_thresh=10000)
    m.drawcountries(linewidth=0.15)
    m.drawmapboundary(fill_color='lightcyan')
    m.fillcontinents(color='whitesmoke',lake_color='lightcyan')
    m.drawcoastlines(color = '0.15',linewidth=0.5)
    # Parallels every 30° (labelled on the left)
    parallels = np.arange(-90.,100.,30.)
    m.drawparallels(parallels,labels=[1,0,0,0],fontsize=10,linewidth=0.3)
    # Meridians every 60° (labelled at the bottom)
    meridians = np.arange(0.,360.,60.)
    m.drawmeridians(meridians,labels=[0,0,0,1],fontsize=10,linewidth=0.3)

    lons, lats = m.makegrid(nx, ny) # get lat/lons of ny by nx evenly spaced grid
    x, y = m(lons, lats) # compute map proj coordinates
    # Move the eastern half of the columns to the front so the field runs from -180° to 180° (centres (0°,0°))
    half = nx // 2
    TobePlot2=np.ma.concatenate((TobePlot[:,half:],TobePlot[:,0:half]),axis=1)
    clevsVPDif = np.arange(Lim1,Lim2,Interval)
    cmap = plt.get_cmap(Colorscheme)
    norm = BoundaryNorm(clevsVPDif, ncolors=cmap.N, clip=False)
    cs = m.pcolormesh(x,y,TobePlot2,zorder=3,cmap=cmap, norm=norm)
    return cs
    
def PlotLatitudinalMean(Latmean,linecolor,ylabel):
    """
    Description: a function to plot latitudinal mean values on the current axes
                 (values on the x-axis, latitude on the y-axis)
    Input:  Latmean--1D latitudinal means ordered from 90° southwards, one value per latitude row
                     (the order returned by Cal_latitudinal_mean)
            linecolor--matplotlib colour/format string of the profile line
            ylabel--legend label of the profile line; it also selects the x-axis range:
                    'ΔTa' -> (-3, 2.5), any other label -> (-1.5, 1)
    """
    n = len(Latmean)
    # n evenly spaced latitudes 90, 90-180/n, ..., -90+180/n (the values np.arange(90,-90,-180/n) yields),
    # but always exactly n of them so they match Latmean
    Lat = 90 - np.arange(n) * (180 / n)
    plt.plot(Latmean,Lat,linecolor,linewidth=1.4,label=ylabel)
    plt.plot(np.zeros(n),Lat,'k--',linewidth=0.7) # A x=0 reference line
    plt.ylim( (-90,90) )
    plt.yticks(np.arange(-90,120, 30),fontsize=10)
    plt.xticks(fontsize=10)

    if ylabel == 'ΔTa':
        plt.xlim( (-3, 2.5) )
    else:
        plt.xlim( (-1.5, 1) )
    
# Short name of the targeted variable
Varname='tas'

# Hour-of-year slices into one yearly hourly file (365-day, no-leap calendar)
JJAHours=slice(3624,5832) # 6.1-8.31
JanFebHours=slice(0,1416) # 1.1-2.28
DecHours=slice(8016,8760) # 12.1-12.31


def _read_summer_diurnal_cycle(FileName, VarKey):
    """
    Description: read one yearly hourly file and return its mean summer diurnal cycle in UTC
    Input:  FileName--path of the yearly hourly netCDF file
            VarKey--name of the variable to read (e.g. 'tasl' or 'tasr')
    Output: masked array (24,LatNum,LonNum). Southern-hemisphere rows (lat index < LatNum//2) hold the
            Jan+Feb+Dec (DJF) mean cycle and northern rows the Jun-Aug (JJA) mean cycle,
            i.e. only summer results are kept in each hemisphere
    """
    with nc.Dataset(FileName) as File:
        Var=File.variables[VarKey]
        MeanJJA=CalMeanDuirnalCycle(Var[JJAHours,:,:])
        # January-February and December of the same year stand in for DJF
        MeanDJF=CalMeanDuirnalCycle(np.ma.concatenate((Var[JanFebHours,:,:],Var[DecHours,:,:]),axis=0))
    gc.collect() # Release the large hourly arrays before the next read
    Half=LatNum//2
    return np.ma.concatenate((MeanDJF[:,0:Half,:],MeanJJA[:,Half:,:]),axis=1)


def _period_lake_minus_rural(Years):
    """
    Description: multi-year mean summer daytime and nighttime Lake-Nonlake air temperatures
    Input:  Years--iterable of years; one lake file and one rural file are read per year
    Output: (DeltaDay, DeltaNight)--(LatNum,LonNum) means over the years of the per-year
            lake-minus-rural daytime (8AM-4PM LT) and nighttime (8PM-4AM LT) means
    Raises: ValueError--if Years is empty
    """
    LakeDay,RuralDay,LakeNight,RuralNight=[],[],[],[]
    for year in Years:
        LakeMean=_read_summer_diurnal_cycle(Dir+'b.e21.SSP585UC_IndividualSoil_'+Varname+'l_Lake_Hourly_'+str(year)+'.nc',Varname+'l')
        RuralMean=_read_summer_diurnal_cycle(Dir+'b.e21.SSP585UC_IndividualSoil_'+Varname+'r_Rural_Hourly_'+str(year)+'.nc',Varname+'r')
        # Convert UTC to local time, then average over the daytime and nighttime hours
        taslDay,taslNight=CalDayNightMean(ConvertUTC2LTC(LakeMean))
        tasrDay,tasrNight=CalDayNightMean(ConvertUTC2LTC(RuralMean))
        LakeDay.append(taslDay); RuralDay.append(tasrDay)
        LakeNight.append(taslNight); RuralNight.append(tasrNight)
    if not LakeDay:
        raise ValueError('_period_lake_minus_rural needs at least one year')
    # Each per-year result is (1,LatNum,LonNum): stack the years on axis 0 and average over them
    DeltaDay=np.nanmean(np.ma.concatenate(LakeDay,axis=0)-np.ma.concatenate(RuralDay,axis=0),axis=0)
    DeltaNight=np.nanmean(np.ma.concatenate(LakeNight,axis=0)-np.ma.concatenate(RuralNight,axis=0),axis=0)
    return DeltaDay,DeltaNight


# Five-year (2019-2023 inclusive) mean Lake-Nonlake air temperatures; range() excludes its stop value
DeltaDay2019,DeltaNight2019=_period_lake_minus_rural(range(2019,2024))

# Five-year (2096-2100 inclusive) mean Lake-Nonlake air temperatures
DeltaDay2096,DeltaNight2096=_period_lake_minus_rural(range(2096,2101))

# Latitudinal-mean profiles for the right-hand panels: the 2019-2023 Lake-Nonlake
# temperatures and their 2096-2100 minus 2019-2023 change
LatMeanDeltaDay2019, LatMeanDeltaNight2019 = (
    Cal_latitudinal_mean(DeltaDay2019),
    Cal_latitudinal_mean(DeltaNight2019),
)
LatMeanDayDelta, LatMeanNightDelta = (
    Cal_latitudinal_mean(DeltaDay2096 - DeltaDay2019),
    Cal_latitudinal_mean(DeltaNight2096 - DeltaNight2019),
)

# Plot Figure
# Layout: two rows of [map | latitudinal profile | spacer | map | latitudinal profile]
fig = plt.figure(figsize=(16, 8.6),constrained_layout=True)
spec5 = fig.add_gridspec(ncols=5, nrows=2, width_ratios=[5, 0.9, 1, 5, 0.9],
                         height_ratios=[1, 1])


def _draw_panel_pair(Row, MapCol, Field, Lim1, Lim2, Step, Title, Cmap, Tag,
                     LatMean, LineColor, LineLabel):
    """
    Description: draw one map panel at spec5[Row,MapCol] and, to its right at spec5[Row,MapCol+1],
                 the matching latitudinal-mean profile
    Input:  Field, Lim1, Lim2, Step, Title, Cmap--forwarded to PlotContour
            Tag--panel label such as '(a)', written below the map
            LatMean, LineColor, LineLabel--forwarded to PlotLatitudinalMean
    Output: the mesh returned by PlotContour (used for the shared colorbars)
    """
    # PlotContour and PlotLatitudinalMean draw on the current axes, i.e. the subplot added just before each call
    MapAx = fig.add_subplot(spec5[Row, MapCol])
    Mesh = PlotContour(Field, Lim1, Lim2, Step, Title, Cmap)
    MapAx.text(0,-114,Tag, horizontalalignment='center',verticalalignment='center', size=14)

    ProfileAx = fig.add_subplot(spec5[Row, MapCol + 1])
    ProfileAx.yaxis.tick_right()
    ProfileAx.tick_params(axis='both',direction='in',labelsize=15)
    ProfileAx.yaxis.set_label_position("right")
    PlotLatitudinalMean(LatMean, LineColor, LineLabel)
    return Mesh


# Top row: 2019-2023 summer Lake-Nonlake Ta, daytime (a) and nighttime (b)
_draw_panel_pair(0, 0, DeltaDay2019, -2.4, 2.7, 0.3,
                 'Summer daytime Lake-Nonlake Ta (2019-2023)', 'rainbow', '(a)',
                 LatMeanDeltaDay2019, 'r', 'ΔTa')
cs1 = _draw_panel_pair(0, 3, DeltaNight2019, -2.4, 2.7, 0.3,
                       'Summer nighttime Lake-Nonlake Ta (2019-2023)', 'rainbow', '(b)',
                       LatMeanDeltaNight2019, 'r', 'ΔTa')

# Bottom row: 2096-2100 minus 2019-2023 change, daytime (c) and nighttime (d)
_draw_panel_pair(1, 0, DeltaDay2096 - DeltaDay2019, -1.4, 1.2, 0.1,
                 'The 2096-2100 minus 2019-2023 summer daytime Lake-Nonlake Ta', 'jet', '(c)',
                 LatMeanDayDelta, 'salmon', '2096-2100 minus 2019-2023 ΔTa')
cs2 = _draw_panel_pair(1, 3, DeltaNight2096 - DeltaNight2019, -1.4, 1.2, 0.1,
                       'The 2096-2100 minus 2019-2023 summer nighttime Lake-Nonlake Ta', 'jet', '(d)',
                       LatMeanNightDelta, 'salmon', '2096-2100 minus 2019-2023 ΔTa')

# NOTE(review): the figure was created with constrained_layout=True. Newer matplotlib versions may skip
# subplots_adjust for such figures, with a warning that is silenced at the top of this file.
# Confirm which layout actually produced the published figure.
plt.subplots_adjust(top=0.935, bottom=0.185, left=0.11, right=0.9, hspace=0.49, wspace=0.115)


def _add_unit_colorbar(Mesh, Bottom, UnitX):
    """
    Description: add a centred horizontal colorbar for Mesh at figure height Bottom and write the
                 '(°C)' unit at x=UnitX (given in the colorbar's value coordinates)
    Output: the Colorbar object
    """
    CbarAx = fig.add_axes([0.5-0.23, Bottom, 0.46, 0.02])
    Cbar = fig.colorbar(Mesh, cax=CbarAx, orientation="horizontal", extend="both")
    Cbar.ax.tick_params(labelsize=11)
    CbarAx.text(UnitX,0,'(°C)', horizontalalignment='center',verticalalignment='center', size=14)
    return Cbar


cbar2 = _add_unit_colorbar(cs2, 0.09, 1.3)  # Lower colorbar, shared by panels (c) and (d)
cbar1 = _add_unit_colorbar(cs1, 0.54, 2.8)  # Upper colorbar, shared by panels (a) and (b)

# plt.savefig('E:\\Keer_work\\DataPaper\\Figure7.png', dpi=450)
# plt.close()